# Description:
#   Low-level implementations of asymmetric hashing.  These are used by
#   both the old, soon-to-be-deprecated API and the new, much more modular
#   API.

load(":template_sharding.bzl", "batch_size_sharder")

# Package-level defaults: targets in this BUILD file are publicly visible
# unless they set their own visibility, and the package is declared under the
# "notice" license type.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Part of the asymmetric hashing implementation that is compiled with
# -fomit-frame-pointer.  copts apply to every source file of a target, so
# in this target the flag only affects
# asymmetric_hashing_impl_omit_frame_pointer.cc.
#
# The exported header, asymmetric_hashing_impl.h, is also exported by
# :asymmetric_hashing_impl, which depends on this target.
cc_library(
    name = "asymmetric_hashing_impl_omit_frame_pointer",
    srcs = ["asymmetric_hashing_impl_omit_frame_pointer.cc"],
    hdrs = ["asymmetric_hashing_impl.h"],
    copts = [
        "-fomit-frame-pointer",
    ],
    tags = ["local"],
    deps = [
        ":asymmetric_hashing_postprocess",
        "//scann/base:restrict_allowlist",
        "//scann/data_format:datapoint",
        "//scann/data_format:dataset",
        "//scann/hashes/asymmetric_hashing2:training_options_base",
        "//scann/oss_wrappers:scann_aligned_malloc",
        "//scann/oss_wrappers:scann_down_cast",
        "//scann/oss_wrappers:tf_dependency",
        "//scann/utils:common",
        "//scann/utils:datapoint_utils",
        "//scann/utils:top_n_amortized_constant",
        "//scann/utils:types",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log",
    ],
)

# Main asymmetric hashing implementation library.  Compiles
# asymmetric_hashing_impl.cc and pulls in the sources built with
# -fomit-frame-pointer through :asymmetric_hashing_impl_omit_frame_pointer.
cc_library(
    name = "asymmetric_hashing_impl",
    srcs = ["asymmetric_hashing_impl.cc"],
    hdrs = ["asymmetric_hashing_impl.h"],
    tags = ["local"],
    deps = [
        ":asymmetric_hashing_impl_omit_frame_pointer",
        ":asymmetric_hashing_postprocess",
        "//scann/base:restrict_allowlist",
        "//scann/data_format:datapoint",
        "//scann/data_format:dataset",
        "//scann/distance_measures/one_to_many",
        "//scann/hashes/asymmetric_hashing2:training_options_base",
        "//scann/oss_wrappers:scann_aligned_malloc",
        "//scann/oss_wrappers:scann_random",
        "//scann/oss_wrappers:scann_status",
        "//scann/oss_wrappers:tf_dependency",
        "//scann/projection:chunking_projection",
        "//scann/utils:common",
        "//scann/utils:datapoint_utils",
        "//scann/utils:gmm_utils",
        "//scann/utils:top_n_amortized_constant",
        "//scann/utils:types",
        "@com_google_absl//absl/random",
        "@com_google_absl//absl/random:distributions",
        "@com_google_absl//absl/status",
    ],
    # The linker gets confused by explicitly instantiated templates.  Setting
    # alwayslink probably shouldn't be needed, but it is empirically necessary
    # to get in_memory_ah_benchmark to link in non-optimized mode.  With
    # alwayslink set, every object file of this library is linked into
    # dependent binaries even if none of its symbols are referenced.
    alwayslink = 1,
)

# Applies batch_size_sharder (loaded from :template_sharding.bzl) to the SSE4
# LUT16 template.  The resulting target is consumed as the srcs of
# :lut16_sse4.
# NOTE(review): presumably produces one source shard per batch size up to
# max_batch_size -- confirm against template_sharding.bzl.
batch_size_sharder(
    name = "lut16_sse4_batches",
    max_batch_size = 9,
    tags = ["local"],
    template = "bazel_templates/lut16_sse4.tpl.cc",
)

# SSE4 build of the LUT16 code.  Its sources come from the
# :lut16_sse4_batches target rather than from checked-in .cc files, and it
# exports lut16_sse4.h.
cc_library(
    name = "lut16_sse4",
    srcs = [":lut16_sse4_batches"],
    hdrs = ["lut16_sse4.h"],
    # We build this library at -O3 even for fastbuild and dbg
    # builds.  With inlining, functions in this library generate
    # enormous stack frames in debug mode, causing stack overflows
    # in unit tests.  Forcing -O3 forces stack slot recycling.
    copts = [
        "-O3",
    ],
    tags = ["local"],
    # Listed as a textual header: lut16_sse4.inc is meant to be #included by
    # other sources, not compiled as a standalone header.
    textual_hdrs = ["lut16_sse4.inc"],
    deps = [
        ":lut16_args",
        "//scann/oss_wrappers:scann_bits",
        "//scann/oss_wrappers:tf_dependency",
        "//scann/utils:bits",
        "//scann/utils:common",
        "//scann/utils:types",
        "//scann/utils/intrinsics:attributes",
        "//scann/utils/intrinsics:sse4",
    ],
)

# Applies batch_size_sharder (loaded from :template_sharding.bzl) to the AVX2
# LUT16 template.  The resulting target is consumed as the srcs of
# :lut16_avx2.
# NOTE(review): presumably produces one source shard per batch size up to
# max_batch_size -- confirm against template_sharding.bzl.
batch_size_sharder(
    name = "lut16_avx2_batches",
    max_batch_size = 9,
    tags = ["local"],
    template = "bazel_templates/lut16_avx2.tpl.cc",
)

# AVX2 build of the LUT16 code.  Its sources come from the
# :lut16_avx2_batches target rather than from checked-in .cc files, and it
# exports lut16_avx2.h.
cc_library(
    name = "lut16_avx2",
    srcs = [":lut16_avx2_batches"],
    hdrs = ["lut16_avx2.h"],
    # We build this library at -O3 even for fastbuild and dbg
    # builds.  With inlining, functions in this library generate
    # enormous stack frames in debug mode, causing stack overflows
    # in unit tests.  Forcing -O3 forces stack slot recycling.
    copts = [
        "-O3",
    ],
    tags = ["local"],
    # Listed as a textual header: lut16_avx2.inc is meant to be #included by
    # other sources, not compiled as a standalone header.
    textual_hdrs = ["lut16_avx2.inc"],
    deps = [
        ":lut16_args",
        "//scann/oss_wrappers:scann_bits",
        "//scann/oss_wrappers:tf_dependency",
        "//scann/utils:bits",
        "//scann/utils:common",
        "//scann/utils:types",
        "//scann/utils/intrinsics:attributes",
        "//scann/utils/intrinsics:avx2",
    ],
)

# Applies batch_size_sharder (loaded from :template_sharding.bzl) to the
# AVX-512 "prefetch" LUT16 template.  The resulting target is one of the srcs
# of :lut16_avx512.
# NOTE(review): presumably produces one source shard per batch size up to
# max_batch_size -- confirm against template_sharding.bzl.
batch_size_sharder(
    name = "lut16_avx512_prefetch",
    max_batch_size = 9,
    tags = ["local"],
    template = "bazel_templates/lut16_avx512_prefetch.tpl.cc",
)

# Applies batch_size_sharder (loaded from :template_sharding.bzl) to the
# AVX-512 "smart" LUT16 template.  The resulting target is one of the srcs of
# :lut16_avx512.
# NOTE(review): presumably produces one source shard per batch size up to
# max_batch_size -- confirm against template_sharding.bzl.
batch_size_sharder(
    name = "lut16_avx512_smart",
    max_batch_size = 9,
    tags = ["local"],
    template = "bazel_templates/lut16_avx512_smart.tpl.cc",
)

# Applies batch_size_sharder (loaded from :template_sharding.bzl) to the
# AVX-512 "noprefetch" LUT16 template.  The resulting target is one of the
# srcs of :lut16_avx512.
# NOTE(review): presumably produces one source shard per batch size up to
# max_batch_size -- confirm against template_sharding.bzl.
batch_size_sharder(
    name = "lut16_avx512_noprefetch",
    max_batch_size = 9,
    tags = ["local"],
    template = "bazel_templates/lut16_avx512_noprefetch.tpl.cc",
)

# AVX-512 build of the LUT16 code.  Combines the hand-written
# lut16_avx512_swizzle.cc with the sources from the three AVX-512
# batch_size_sharder targets (noprefetch, prefetch, smart), and exports
# lut16_avx512.h and lut16_avx512_swizzle.h.
cc_library(
    name = "lut16_avx512",
    srcs = [
        "lut16_avx512_swizzle.cc",
        ":lut16_avx512_noprefetch",
        ":lut16_avx512_prefetch",
        ":lut16_avx512_smart",
    ],
    hdrs = [
        "lut16_avx512.h",
        "lut16_avx512_swizzle.h",
    ],
    # We build this library at -O3 even for fastbuild and dbg
    # builds.  With inlining, functions in this library generate
    # enormous stack frames in debug mode, causing stack overflows
    # in unit tests.  Forcing -O3 forces stack slot recycling.
    copts = [
        "-O3",
    ],
    tags = ["local"],
    # Listed as a textual header: lut16_avx512.inc is meant to be #included by
    # other sources, not compiled as a standalone header.
    textual_hdrs = ["lut16_avx512.inc"],
    deps = [
        ":lut16_args",
        "//scann/oss_wrappers:scann_aligned_malloc",
        "//scann/oss_wrappers:scann_bits",
        "//scann/oss_wrappers:scann_down_cast",
        "//scann/oss_wrappers:tf_dependency",
        "//scann/utils:bits",
        "//scann/utils:common",
        "//scann/utils:types",
        "//scann/utils/intrinsics:attributes",
        "//scann/utils/intrinsics:avx512",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log",
    ],
)

# Library built from write_distances_to_topn.cc / .h.  It depends on the
# asymmetric hashing postprocess header library, the amortized-constant top-N
# utility, and the SSE4 intrinsics wrappers.
cc_library(
    name = "write_distances_to_topn",
    srcs = ["write_distances_to_topn.cc"],
    hdrs = ["write_distances_to_topn.h"],
    tags = ["local"],
    deps = [
        ":asymmetric_hashing_postprocess",
        "//scann/base:restrict_allowlist",
        "//scann/utils:top_n_amortized_constant",
        "//scann/utils:types",
        "//scann/utils/intrinsics:sse4",
    ],
)

# Header-only target (no srcs) exporting asymmetric_hashing_lut16.h.  It sits
# on top of :lut16_interface, :write_distances_to_topn and
# :asymmetric_hashing_postprocess.
cc_library(
    name = "asymmetric_hashing_lut16",
    hdrs = ["asymmetric_hashing_lut16.h"],
    tags = ["local"],
    deps = [
        ":asymmetric_hashing_postprocess",
        ":lut16_interface",
        ":write_distances_to_topn",
        "//scann/base:restrict_allowlist",
        "//scann/data_format:dataset",
        "//scann/utils:top_n_amortized_constant",
        "//scann/utils:types",
    ],
)

# Header-only target (no srcs) exporting asymmetric_hashing_postprocess.h.
# Within this package it is a dependency of both asymmetric_hashing_impl
# targets, :write_distances_to_topn and :asymmetric_hashing_lut16.
cc_library(
    name = "asymmetric_hashing_postprocess",
    hdrs = ["asymmetric_hashing_postprocess.h"],
    tags = ["local"],
    deps = [
        "//scann/oss_wrappers:scann_aligned_malloc",
        "//scann/oss_wrappers:scann_down_cast",
        "//scann/oss_wrappers:tf_dependency",
        "//scann/utils:types",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/log",
    ],
)

# Header-only target (no srcs) exporting lut16_args.h.  It is a shared
# dependency of :lut16_sse4, :lut16_avx2, :lut16_avx512 and :lut16_interface.
cc_library(
    name = "lut16_args",
    hdrs = ["lut16_args.h"],
    tags = ["local"],
    deps = [
        "//scann/base:restrict_allowlist",
        "//scann/utils:fast_top_neighbors",
        "//scann/utils:types",
    ],
)

# Library built from lut16_interface.cc / .h that depends on all three
# instruction-set-specific LUT16 libraries (SSE4, AVX2, AVX-512).
# NOTE(review): presumably selects among those variants at runtime using
# //scann/utils/intrinsics:flags -- confirm in lut16_interface.cc.
cc_library(
    name = "lut16_interface",
    srcs = ["lut16_interface.cc"],
    hdrs = ["lut16_interface.h"],
    tags = ["local"],
    deps = [
        ":lut16_args",
        ":lut16_avx2",
        ":lut16_avx512",
        ":lut16_sse4",
        "//scann/utils:alignment",
        "//scann/utils:common",
        "//scann/utils:fast_top_neighbors",
        "//scann/utils:types",
        "//scann/utils/intrinsics:flags",
    ],
)

# Library built from stacked_quantizers.cc / .h.  It does not depend on any
# other target in this package; its dependencies are the data format,
# distance measure, training options, thread pool and utility libraries
# listed below.
cc_library(
    name = "stacked_quantizers",
    srcs = ["stacked_quantizers.cc"],
    hdrs = ["stacked_quantizers.h"],
    tags = ["local"],
    deps = [
        "//scann/data_format:datapoint",
        "//scann/data_format:dataset",
        "//scann/data_format:docid_collection",
        "//scann/distance_measures",
        "//scann/distance_measures/many_to_many",
        "//scann/distance_measures/one_to_many",
        "//scann/distance_measures/one_to_one:dot_product",
        "//scann/hashes/asymmetric_hashing2:training_options_base",
        "//scann/oss_wrappers:scann_threadpool",
        "//scann/utils:common",
        "//scann/utils:datapoint_utils",
        "//scann/utils:dataset_sampling",
        "//scann/utils:gmm_utils",
        "//scann/utils:types",
        "@com_google_absl//absl/log",
    ],
)
